# -*- coding: utf-8 -*-

# Copyright 2011 Tom SF Haines

# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



import exceptions
import cPickle as pickle

import numpy



class Mixture:
  """Defines the basic interface to a mixture model - the methods to train, and then classify.

  Inheriting classes implement the do* methods, which always receive a 2D numpy data matrix with one feature vector per row. The public train/get* methods accept three input forms and convert them before delegating:
   - a 2D numpy.ndarray data matrix,
   - a list of numpy.ndarray feature vectors,
   - a list of tuples where the last entry of each tuple is a feature vector.
  """
  def clusterCount(self):
    """To be implemented by the inheriting class - returns how many elements the mixture has, i.e. the cluster count."""
    raise NotImplementedError()

  def parameters(self):
    """To be implemented by the inheriting class - returns how many parameters the model has."""
    raise NotImplementedError()


  def doTrain(self,feats,clusters):
    """To be implemented by the inheriting class, possibly with further implementation specific parameters. Accepts a data matrix and a cluster count."""
    raise NotImplementedError()

  def doGetWeight(self,feats):
    """To be implemented by the inheriting class. Given a data matrix returns a new matrix, where the feature vectors have been replaced by the probabilities of the feature being generated by each of the clusters."""
    raise NotImplementedError()

  def doGetCluster(self,feats):
    """To be implemented by the inheriting class. Given a data matrix returns an integer vector, where each feature vector has been assigned the cluster from which it most likely comes."""
    raise NotImplementedError()

  def doGetNLL(self,feats):
    """To be implemented by the inheriting class. Given a data matrix returns the negative log likelihood of the data coming from this model."""
    raise NotImplementedError()


  def _toDataMatrix(self,feats):
    """Converts any of the supported feature containers into a data matrix.

    Returns a (data, form) pair - data is the matrix to hand to the do* methods and form is 'matrix', 'vectors' or 'tuples', recording how feats was supplied so that the caller can shape its output to match.

    Raises TypeError if feats is neither a numpy.ndarray nor a list, or is a list whose first entry is neither a numpy.ndarray nor a tuple. Raises IndexError if feats is an empty list. Only the first entry of a list is inspected to decide its form."""
    if isinstance(feats,numpy.ndarray):
      # Numpy array - just pass through. The shape check remains an assert, as it always has been, so the exception type seen by callers is unchanged...
      assert len(feats.shape)==2, 'a feature array must be a 2D data matrix'
      return feats, 'matrix'

    if not isinstance(feats,list):
      raise TypeError('bad type for features - expects a numpy.array or a list')

    if len(feats)==0:
      # Indexing the first entry would fail anyway - fail with the same exception type, but say why...
      raise IndexError('bad features - an empty list was given, at least one feature vector is required')

    first = feats[0]
    if isinstance(first,numpy.ndarray):
      # List of vectors - glue them all together to create a data matrix...
      return numpy.vstack(feats), 'vectors'

    if isinstance(first,tuple):
      # List of tuples where the last item should be a numpy.array as a feature vector...
      return numpy.vstack([entry[-1] for entry in feats]), 'tuples'

    raise TypeError('bad type for features - when given a list it must contain numpy.array vectors or tuples with the last element a vector')


  def train(self,feats,clusters,*listArgs,**dictArgs):
    """Trains the model on the given data set.

    feats - the features, as a data matrix, a list of vectors or a list of tuples where the last entry is a feature vector.
    clusters - how many clusters to use.
    listArgs, dictArgs - any further implementation specific arguments, forwarded untouched to doTrain.

    Returns nothing - whatever doTrain returns is discarded."""
    data, _ = self._toDataMatrix(feats)
    self.doTrain(data,clusters,*listArgs,**dictArgs)


  def getWeight(self,feats):
    """Given a set of features returns for each feature the probability of having been generated by each mixture member - this vector will sum to one.

    Multiple input/output modes are supported. If a data matrix is input then the output is whatever doGetWeight returns, a matrix where each row has been replaced by the mixing vector. If the input is a list of feature vectors the output is a list of mixing vectors. If the input is a list of tuples with the last element a feature vector the output is an identical list, but with the feature vectors replaced with mixing vectors."""
    data, form = self._toDataMatrix(feats)
    asMat = self.doGetWeight(data)

    if form=='matrix':
      return asMat

    rows = range(asMat.shape[0])
    if form=='vectors':
      return [asMat[i,:] for i in rows]
    return [feats[i][:-1] + (asMat[i,:],) for i in rows]


  def getCluster(self,feats):
    """Given a set of features returns for each feature the cluster with the highest probability of having generated it.

    Multiple input/output modes are supported. If a data matrix is input then the output is whatever doGetCluster returns, an integer vector of cluster indices. If the input is a list of feature vectors the output is a list of cluster indices. If the input is a list of tuples with the last element a feature vector the output is an identical list, but with the feature vectors replaced by cluster integers."""
    data, form = self._toDataMatrix(feats)
    asVec = self.doGetCluster(data)

    if form=='matrix':
      return asVec

    rows = range(asVec.shape[0])
    if form=='vectors':
      return [asVec[i] for i in rows]
    return [feats[i][:-1] + (asVec[i],) for i in rows]


  def getNLL(self,feats):
    """Given a set of features, in any of the supported input forms, returns the negative log likelihood of the given features being generated by the model, as calculated by doGetNLL."""
    data, _ = self._toDataMatrix(feats)
    return self.doGetNLL(data)


  def getData(self):
    """Returns the data contained within, so it can be serialised with other data. (You can of course serialise this class directly if you want, but the returned object is a tuple of numpy arrays, so less likely to be an issue for any program that loads it.) To be implemented by the inheriting class."""
    raise NotImplementedError()

  def setData(self,data):
    """Sets the data for the object, should be same form as returned from getData. To be implemented by the inheriting class."""
    raise NotImplementedError()
